{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "dc93fc27-8514-4a8d-8b0f-b784b31cf1fb",
   "metadata": {},
   "source": [
    "# Load the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "e430cf6f-941d-4cb5-88db-1e15345cbdab",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "# Load the FAQ table; the path is resolved relative to the working directory.\n",
    "data = pd.read_csv(filepath_or_buffer='Mental_Health_FAQ.csv')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "01a56ff8-42b1-458a-9b1e-9e513e80f8ef",
   "metadata": {},
   "source": [
    "# Generate embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "57c536a8-0d82-42fe-887c-9391dfa898d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sentence_transformers import SentenceTransformer\n",
    "question_emb_model = SentenceTransformer('thenlper/gte-base')\n",
    "\n",
    "# Encode all questions in one batched call instead of invoking the model once\n",
    "# per row through `.apply`; normalize_embeddings=True is kept unchanged.\n",
    "question_embs = question_emb_model.encode(data['Questions'].tolist(), normalize_embeddings=True)\n",
    "# Store one vector per row (a 2-D array cannot be assigned to a single column).\n",
    "data['question_emb'] = list(question_embs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "d38d3230-55a7-4c44-b40b-72c5a19a289c",
   "metadata": {},
   "outputs": [],
   "source": [
    "answer_emb_model = SentenceTransformer('BAAI/bge-large-en-v1.5')\n",
    "\n",
    "# Encode all answers in one batched call instead of invoking the model once\n",
    "# per row through `.apply`; normalize_embeddings=True is kept unchanged.\n",
    "answer_embs = answer_emb_model.encode(data['Answers'].tolist(), normalize_embeddings=True)\n",
    "# Store one vector per row (a 2-D array cannot be assigned to a single column).\n",
    "data['answer_emb'] = list(answer_embs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "84851c40-5224-412d-836e-bb8d9efc0a9e",
   "metadata": {},
   "source": [
    "# Index documents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "c2141d2e-77c8-46fa-8ae4-a24b2af19640",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from getpass import getpass\n",
    "from ssl import create_default_context\n",
    "\n",
    "from elasticsearch import Elasticsearch\n",
    "\n",
    "# Credentials come from the environment (with an interactive prompt as the\n",
    "# fallback) so that no secret is stored in the notebook itself.\n",
    "# NOTE(review): a password was previously committed in this cell; rotate it.\n",
    "es_user = os.environ.get('ES_USERNAME', 'elastic')\n",
    "es_password = os.environ.get('ES_PASSWORD') or getpass('Elasticsearch password: ')\n",
    "\n",
    "# Verify the server certificate against the CA file in ./http_ca.crt.\n",
    "context = create_default_context(cafile=r\"./http_ca.crt\")\n",
    "es = Elasticsearch(\n",
    "    'https://localhost:9200',\n",
    "    basic_auth=(es_user, es_password),\n",
    "    ssl_context=context,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "17b4444a-c065-4f75-bdcd-34a37022a862",
   "metadata": {},
   "outputs": [],
   "source": [
    "index_name = \"faq-index\"\n",
    "\n",
    "\n",
    "def generate_docs():\n",
    "    \"\"\"Yield one bulk action per row of `data`, addressed to `index_name`.\"\"\"\n",
    "    # Indexed field name -> dataframe column that supplies its value.\n",
    "    field_to_column = {\n",
    "        \"faq_id\": 'Question_ID',\n",
    "        \"question\": 'Questions',\n",
    "        \"answer\": 'Answers',\n",
    "        \"question_emb\": 'question_emb',\n",
    "        \"answer_emb\": 'answer_emb',\n",
    "    }\n",
    "    for _, faq_row in data.iterrows():\n",
    "        source = {field: faq_row[column] for field, column in field_to_column.items()}\n",
    "        yield {\"_index\": index_name, \"_source\": source}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "fe38445a-f025-44fc-92b9-70c0509de787",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "100%|████████████████████████████████████████████████████████| 98/98 [28:19<00:00, 17.34s/docs]\u001b[A\n",
      "\n",
      "  1%|▌                                                        | 1/98 [00:00<00:41,  2.32docs/s]\u001b[A"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Indexed 98/98 documents\n"
     ]
    }
   ],
   "source": [
    "import tqdm\n",
    "from elasticsearch.helpers import streaming_bulk\n",
    "number_of_docs = len(data)\n",
    "successes = 0\n",
    "# The context manager closes the bar when indexing finishes (or fails), so a\n",
    "# stale bar from this run cannot bleed into the output of a later run.\n",
    "with tqdm.tqdm(unit=\"docs\", total=number_of_docs) as progress:\n",
    "    # One (ok, result) pair is produced per action sent to the cluster.\n",
    "    for ok, action in streaming_bulk(client=es, index=index_name, actions=generate_docs()):\n",
    "        progress.update(1)\n",
    "        successes += ok\n",
    "\n",
    "print(\"Indexed %d/%d documents\" % (successes, number_of_docs))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "00c93af5-d99e-465c-9754-bd41641a21ba",
   "metadata": {},
   "source": [
    "# Query documents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "9e469476-dcac-42ea-962e-e870864d54d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def faq_search(query=\"\", k=10, num_candidates=10):\n",
    "    \"\"\"Search the FAQ index with a lexical match combined with two kNN clauses.\n",
    "\n",
    "    Args:\n",
    "        query: Free-text user question.\n",
    "        k: Value passed as `k` to each kNN clause.\n",
    "        num_candidates: Value passed as `num_candidates` to each kNN clause.\n",
    "\n",
    "    Returns:\n",
    "        The hit list from the Elasticsearch response (request size is 10,\n",
    "        `_source` limited to faq_id/question/answer), or None when `query`\n",
    "        is None or empty.\n",
    "    \"\"\"\n",
    "    # `not query` also rejects None, which previously slipped past the\n",
    "    # emptiness check and crashed further down.\n",
    "    if not query:\n",
    "        print('Query cannot be empty')\n",
    "        return None\n",
    "\n",
    "    query_question_emb = question_emb_model.encode(query, normalize_embeddings=True)\n",
    "\n",
    "    # The answer-embedding model is queried with this instruction prefix.\n",
    "    instruction = \"Represent this sentence for searching relevant passages: \"\n",
    "    query_answer_emb = answer_emb_model.encode(instruction + query, normalize_embeddings=True)\n",
    "\n",
    "    payload = {\n",
    "        \"query\": {\n",
    "            \"match\": {\n",
    "                # Indexed documents carry `question`/`answer` text fields; the\n",
    "                # former `title` target does not exist in them, so the lexical\n",
    "                # clause could never match anything.\n",
    "                \"question\": {\n",
    "                    \"query\": query,\n",
    "                    \"boost\": 0.2\n",
    "                }\n",
    "            }\n",
    "        },\n",
    "        \"knn\": [\n",
    "            {\n",
    "                \"field\": \"question_emb\",\n",
    "                \"query_vector\": query_question_emb,\n",
    "                \"k\": k,\n",
    "                \"num_candidates\": num_candidates,\n",
    "                \"boost\": 0.3\n",
    "            },\n",
    "            {\n",
    "                \"field\": \"answer_emb\",\n",
    "                \"query_vector\": query_answer_emb,\n",
    "                \"k\": k,\n",
    "                \"num_candidates\": num_candidates,\n",
    "                \"boost\": 0.5\n",
    "            }\n",
    "        ],\n",
    "        \"size\": 10,\n",
    "        \"_source\": [\"faq_id\", \"question\", \"answer\"]\n",
    "    }\n",
    "\n",
    "    # Only the list of hits is returned, not the full response envelope.\n",
    "    return es.search(index=index_name, body=payload)['hits']['hits']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f57f9bda-b4c2-4c19-b7f4-bc225f450899",
   "metadata": {},
   "source": [
    "# Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "e853459e-987a-487c-9725-212324561c38",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.\n",
      "The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.\n",
      "The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.\n",
      "/Users/liuxg/anaconda3/lib/python3.11/site-packages/transformers/models/auto/modeling_auto.py:1468: FutureWarning: The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use `AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and `AutoModelForSeq2SeqLM` for encoder-decoder models.\n",
      "  warnings.warn(\n",
      "The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can safely remove it from your `config.json` file.\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoModelForSeq2SeqLM, AutoTokenizer\n",
    "\n",
    "PARAPHRASE_MODEL_NAME = \"mrm8488/t5-small-finetuned-quora-for-paraphrasing\"\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(PARAPHRASE_MODEL_NAME)\n",
    "# AutoModelWithLMHead is deprecated in Transformers; this T5 checkpoint is an\n",
    "# encoder-decoder model, so AutoModelForSeq2SeqLM is the replacement to use.\n",
    "model = AutoModelForSeq2SeqLM.from_pretrained(PARAPHRASE_MODEL_NAME)\n",
    "\n",
    "\n",
    "def paraphrase(question, number_of_questions=3, max_length=128):\n",
    "    \"\"\"Generate paraphrases of `question` with beam search.\n",
    "\n",
    "    Args:\n",
    "        question: Input text handed to the tokenizer as-is.\n",
    "        number_of_questions: Number of sequences requested from `generate`.\n",
    "        max_length: Maximum length passed to `generate`.\n",
    "\n",
    "    Returns:\n",
    "        A list with one decoded string per generated sequence.\n",
    "    \"\"\"\n",
    "    input_ids = tokenizer.encode(question, return_tensors=\"pt\", add_special_tokens=True)\n",
    "\n",
    "    generated_ids = model.generate(\n",
    "        input_ids=input_ids,\n",
    "        num_return_sequences=number_of_questions,\n",
    "        num_beams=5,\n",
    "        max_length=max_length,\n",
    "        no_repeat_ngram_size=2,\n",
    "        repetition_penalty=3.5,\n",
    "        length_penalty=1.0,\n",
    "        early_stopping=True,\n",
    "    )\n",
    "\n",
    "    return [\n",
    "        tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)\n",
    "        for ids in generated_ids\n",
    "    ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "0670c668-4d08-4135-a91b-df06844e174d",
   "metadata": {},
   "outputs": [],
   "source": [
    "temp_data = data[['Question_ID', 'Questions']]\n",
    "\n",
    "# One evaluation record per generated paraphrase, labelled with the id of the\n",
    "# FAQ entry it was derived from.\n",
    "eval_records = [\n",
    "    {'Question': pred, 'FAQ_ID': row['Question_ID']}\n",
    "    for _, row in temp_data.iterrows()\n",
    "    for pred in paraphrase(\"paraphrase: {}\".format(row['Questions']))\n",
    "]\n",
    "\n",
    "# Shuffle with a fixed seed so the row order is reproducible across runs.\n",
    "eval_data = pd.DataFrame(eval_records).sample(frac=1, random_state=42).reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "a030ffce-8e6e-46e7-bf3f-062a738d20ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_faq_id_s1(query=\"\", k=5, num_candidates=10):\n",
    "    \"\"\"System 1: return the faq_id of the top kNN hit on `answer_emb`.\n",
    "\n",
    "    Args:\n",
    "        query: Free-text user question.\n",
    "        k: Value passed as `k` to the kNN clause.\n",
    "        num_candidates: Value passed as `num_candidates` to the kNN clause.\n",
    "\n",
    "    Returns:\n",
    "        The `faq_id` of the first hit, or None when `query` is None/empty or\n",
    "        the search returns no hits.\n",
    "    \"\"\"\n",
    "    # `not query` also rejects None, which previously slipped past the\n",
    "    # emptiness check and crashed in the string concatenation below.\n",
    "    if not query:\n",
    "        print('Query cannot be empty')\n",
    "        return None\n",
    "\n",
    "    # The answer-embedding model is queried with this instruction prefix.\n",
    "    instruction = \"Represent this sentence for searching relevant passages: \"\n",
    "    query_answer_emb = answer_emb_model.encode(instruction + query, normalize_embeddings=True)\n",
    "\n",
    "    payload = {\n",
    "        \"knn\": [\n",
    "            {\n",
    "                \"field\": \"answer_emb\",\n",
    "                \"query_vector\": query_answer_emb,\n",
    "                \"k\": k,\n",
    "                \"num_candidates\": num_candidates,\n",
    "            }\n",
    "        ],\n",
    "        \"size\": 1,\n",
    "        \"_source\": [\"faq_id\"]\n",
    "    }\n",
    "\n",
    "    hits = es.search(index=index_name, body=payload)['hits']['hits']\n",
    "\n",
    "    # Guard against an empty result instead of raising IndexError on hits[0].\n",
    "    if not hits:\n",
    "        return None\n",
    "    return hits[0]['_source']['faq_id']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "dccf12c6-954b-47a6-881f-5868c7dca8e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_faq_id_s2(query=\"\", k=5, num_candidates=10):\n",
    "    \"\"\"System 2: return the faq_id of the top hit of the hybrid search.\n",
    "\n",
    "    The request combines a lexical match with kNN clauses on `question_emb`\n",
    "    and `answer_emb`.\n",
    "\n",
    "    Args:\n",
    "        query: Free-text user question.\n",
    "        k: Value passed as `k` to each kNN clause.\n",
    "        num_candidates: Value passed as `num_candidates` to each kNN clause.\n",
    "\n",
    "    Returns:\n",
    "        The `faq_id` of the first hit, or None when `query` is None/empty or\n",
    "        the search returns no hits.\n",
    "    \"\"\"\n",
    "    # `not query` also rejects None, which previously slipped past the\n",
    "    # emptiness check and crashed further down.\n",
    "    if not query:\n",
    "        print('Query cannot be empty')\n",
    "        return None\n",
    "\n",
    "    query_question_emb = question_emb_model.encode(query, normalize_embeddings=True)\n",
    "\n",
    "    # The answer-embedding model is queried with this instruction prefix.\n",
    "    instruction = \"Represent this sentence for searching relevant passages: \"\n",
    "    query_answer_emb = answer_emb_model.encode(instruction + query, normalize_embeddings=True)\n",
    "\n",
    "    payload = {\n",
    "        \"query\": {\n",
    "            \"match\": {\n",
    "                # Indexed documents carry `question`/`answer` text fields; the\n",
    "                # former `title` target does not exist in them, so the lexical\n",
    "                # clause could never match anything.\n",
    "                \"question\": {\n",
    "                    \"query\": query,\n",
    "                    \"boost\": 0.2\n",
    "                }\n",
    "            }\n",
    "        },\n",
    "        \"knn\": [\n",
    "            {\n",
    "                \"field\": \"question_emb\",\n",
    "                \"query_vector\": query_question_emb,\n",
    "                \"k\": k,\n",
    "                \"num_candidates\": num_candidates,\n",
    "                \"boost\": 0.3\n",
    "            },\n",
    "            {\n",
    "                \"field\": \"answer_emb\",\n",
    "                \"query_vector\": query_answer_emb,\n",
    "                \"k\": k,\n",
    "                \"num_candidates\": num_candidates,\n",
    "                \"boost\": 0.5\n",
    "            }\n",
    "        ],\n",
    "        \"size\": 1,\n",
    "        \"_source\": [\"faq_id\"]\n",
    "    }\n",
    "\n",
    "    hits = es.search(index=index_name, body=payload)['hits']['hits']\n",
    "\n",
    "    # Guard against an empty result instead of raising IndexError on hits[0].\n",
    "    if not hits:\n",
    "        return None\n",
    "    return hits[0]['_source']['faq_id']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "522cde46-e715-4a78-a830-b18f08161e73",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/var/folders/kw/pvmj1zgs44l8t32cs8blf6sc0000gn/T/ipykernel_26974/710032138.py:23: DeprecationWarning: The 'body' parameter is deprecated and will be removed in a future version. Instead use individual parameters.\n",
      "  response = es.search(index=index_name, body=payload)['hits']['hits']\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "System 1 Accuracy: 0.7312925170068028\n"
     ]
    }
   ],
   "source": [
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "# Predict an FAQ id for every evaluation question with System 1.\n",
    "eval_data['PRED_FAQ_ID_S1'] = eval_data['Question'].apply(get_faq_id_s1)\n",
    "\n",
    "# `ground_truth` is also read by the System 2 evaluation cell.\n",
    "ground_truth = eval_data[\"FAQ_ID\"].values\n",
    "predictions_s1 = eval_data[\"PRED_FAQ_ID_S1\"].values\n",
    "\n",
    "# Share of questions whose predicted id equals the labelled id.\n",
    "s1_accuracy = accuracy_score(y_true=ground_truth, y_pred=predictions_s1)\n",
    "print('System 1 Accuracy: {}'.format(s1_accuracy))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "4ce73526-eabb-436a-ae29-b76ba1ec5182",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/var/folders/kw/pvmj1zgs44l8t32cs8blf6sc0000gn/T/ipykernel_26974/821299667.py:40: DeprecationWarning: The 'body' parameter is deprecated and will be removed in a future version. Instead use individual parameters.\n",
      "  response = es.search(index=index_name, body=payload)['hits']['hits']\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "System 2 Accuracy: 0.8401360544217688\n"
     ]
    }
   ],
   "source": [
    "# Predict an FAQ id for every evaluation question with System 2.\n",
    "eval_data['PRED_FAQ_ID_S2'] = eval_data['Question'].apply(get_faq_id_s2)\n",
    "predictions_s2 = eval_data[\"PRED_FAQ_ID_S2\"].values\n",
    "\n",
    "# Scored against the `ground_truth` array built in the System 1 cell.\n",
    "s2_accuracy = accuracy_score(y_true=ground_truth, y_pred=predictions_s2)\n",
    "print('System 2 Accuracy: {}'.format(s2_accuracy))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b9c8d51-6ffa-43bb-9452-642275a34a55",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python",
   "language": "python",
   "name": "py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
